#pragma once

/*
 * anet default interface implementation.
 */
#include <string>
#include <mutex>
#include <atomic>
#include <map>
#include <functional>
#include "anet.hpp"
#include "log.h"
#include "define.hpp"
#include "semaphore.hpp"
#include "connection.hpp"

namespace anet {
	namespace tcp {
		// Short integer aliases used by the tcp protocol code below.
		// NOTE(review): the names suggest fixed widths, but the underlying types are
		// platform dependent (unsigned long in particular is 4 bytes on some ABIs and
		// 8 bytes on others) -- confirm wherever an alias ends up in a wire format.
		using uint32 = unsigned int;
		using uint16 = unsigned short;
		using uint64 = unsigned long long;
		using ulong  = unsigned long;
		// Sentinel message id; a default-constructed response carries this value.
		static constexpr uint16 InvalidMsgId = 0xffff;

		// A reply message: a numeric message id plus its raw payload.
		struct response {
			// message id (InvalidMsgId for a default-constructed response)
			uint16 msgId;

			// message payload
			std::string message;

			// Builds an empty response tagged with InvalidMsgId.
			response() : msgId(InvalidMsgId), message() {}

			// Builds a response from an id and a copy of msg.
			response(uint16 id, const std::string& msg) : msgId(id), message(msg) {}
		};

		// response info.
		struct responseInfo {
			responseInfo() : sem(0) {}

			// response message
			response resp;

			// semaphore
			utils::CSemaphore sem;
		};

		// TCP packet header. The definition is wrapped in pack(push,1)/pack(pop) so
		// the struct is laid out with 1-byte packing.
    #pragma pack(push,1)
		struct SCommonHead {
			// length field; the codecs in this file treat it as the body length,
			// i.e. the header itself is not counted
			uint32 len;
		};
    #pragma pack(pop)

		// protocol head size in bytes (size of SCommonHead)
		constexpr int gProto_head_size = sizeof(SCommonHead);

		// size in bytes of the message id field that follows the header
		constexpr int gProto_message_id_size = sizeof(uint16);

		// big codec class.
		// packet format: len(4byte) + body
		class CBigCodec : public ICodec {
		public:
			CBigCodec() = default;
			virtual ~CBigCodec() = default;

		public:
			virtual int parsePacket(const char *data, int len) override {
				if (len < gProto_head_size) {
					return retNotComplete;
				}

				const SCommonHead* pHead = (const SCommonHead*)data;
				auto dataLen = ntohl(pHead->len);
				if (int(dataLen) > gMaxPacketSize) {
					return retError;
				}

				// complete packet check
				if (len >= int(dataLen + gProto_head_size)) {
					return dataLen + gProto_head_size;
				} else {
					return retNotComplete;
				}
			}
		};

		// CCodec is an alias for CBigCodec.
		using CCodec = CBigCodec;

		// little codec.
		class CLittleCodec : public ICodec {
		public:
			CLittleCodec() = default;
			virtual ~CLittleCodec() = default;

		public:
			virtual int parsePacket(const char *data, int len) override {
				if (len < gProto_head_size) {
					return retNotComplete;
				}

				const SCommonHead* pHead = (const SCommonHead*)data;
				auto dataLen = pHead->len;
				// max packet size check.
				if (int(dataLen) > gMaxPacketSize) {
					return retError;
				}

				// complete packet check
				if (len >= int(dataLen + gProto_head_size)) {
					return dataLen + gProto_head_size;
				} else {
					return retNotComplete;
				}
			}
		};

		// === synchronous session === //
		// Mutex-guarded table from message object id to the responseInfo registered
		// for it. The stored pointers are never deleted by this class.
		// Meant for internal use only.
		class responseMap final {
		public:
			responseMap() = default;
			virtual ~responseMap() = default;

		public:
			// Returns the entry stored under objId, or nullptr when there is none.
			responseInfo* get(ulong objId) const {
				std::lock_guard<std::mutex> lock(m_mu);
				const auto pos = m_responses.find(objId);
				if (pos != m_responses.end()) {
					return pos->second;
				}
				return nullptr;
			}

			// Drops the entry stored under objId, if any.
			void remove(ulong objId) {
				std::lock_guard<std::mutex> lock(m_mu);
				m_responses.erase(objId);
			}

			// Stores pInfo under objId, replacing any previous entry.
			void add(ulong objId, responseInfo *pInfo) {
				std::lock_guard<std::mutex> lock(m_mu);
				m_responses[objId] = pInfo;
			}

		private:
			using objId2MapType = std::map<ulong, responseInfo*>;

			// object id -> pending response
			objId2MapType m_responses;

			// guards m_responses; mutable so that get() can stay const
			mutable std::mutex m_mu;
		};


		// Message handler type for CSyncSession.
		// CSyncSession::onRecv invokes it for packets whose message object id is 0,
		// passing the session, that object id and a pointer to the payload bytes.
		// No payload length is passed to the handler.
		class CSyncSession;
		using SyncMessageProcessFunc = std::function<void(CSyncSession*, ulong, const char*)>;

		//
		// sync session: with CBigCodec codec:
		// len(4) + msgId(2) + message unique id(4) + message data
		//
		class CSyncSession : public ISession {
		public:
			CSyncSession() {
				for (auto i = 0;i < gResponseSize;i++) {
					m_responses[i] = new responseMap();
				}
			}
			virtual ~CSyncSession() {
				for (auto i = 0;i < gResponseSize;i++) {
					if (m_responses[i] != nullptr) {
						delete m_responses[i];
						m_responses[i] = nullptr;
					}
				}
			}

		public:
			virtual void onRecv(const char *data, int len) override {
				assert(len >= gProto_head_size && "data len is error");
				const char *content = data + gProto_head_size;

				// message id
				auto msgId = ntohs(*(uint16*)(content));

				// message unique id.
				ulong msgObjId = ntohl(*(ulong*)(content + gProto_message_id_size));

				// response content
				const char *strResp = content + gProto_message_id_size + sizeof(ulong);
				if (msgObjId != 0) {
					// for synchronous mode
					if (auto resMap = this->find(msgObjId)) {
						if (auto respInfo = resMap->get(msgObjId)) {
							// build response data.
							auto msgLen = int(len - gProto_head_size - gProto_message_id_size - sizeof(ulong));
							auto &&resp = std::string(strResp, msgLen);

							// copy response content.
							respInfo->resp = std::move(response{ msgId, std::move(resp) });

							// signal to notify require() function to get response.
							respInfo->sem.signal();
						}
					}
				} else {
					// for asynchronous mode as the msgObjId == 0.
					if (m_msgHandler != nullptr) {
						m_msgHandler(this, msgObjId, strResp);
					} else {
						LogAdebug("can not find message {} handler", msgId);
					}
				}
			}

			virtual void onTerminate() override {
				m_closed = true;
				LogAinfo("{}:{} disconnected", m_conn->getRemoteAddr(), m_conn->getRemotePort());
			}

			virtual void onConnected(connSharePtr conn) override {
				m_conn = conn;
				m_closed = false;
				LogAinfo("{}:{} connected", conn->getRemoteAddr(), conn->getRemotePort());
			}

			virtual void release() override {
				delete this;
			}

			bool isClosed() const {
				return m_closed;
			}

			// set asynchronous message handler.
			void setMessageHandler(SyncMessageProcessFunc handler) {
				m_msgHandler = std::move(handler);
			}

			// require gets a requirement synchronously.
			response require(uint16 msgId, const std::string &data) {
				if (sizeof(uint16) + sizeof(ulong) + data.length() > gMaxPacketSize) {
					return response{ 0,"" };
				}

				char vecBuff[gMaxPacketSize];
				// message unique id.
				ulong msgObjId = m_nextId++;

				// save to response map.
				responseInfo *waitResp = new responseInfo();
				if (!this->add(msgObjId, waitResp)) {
					delete waitResp;
					return response{ 0,"" };
				}

				// send message to server.
				int allLen = buildProto(msgId, msgObjId, data, vecBuff);
				if (!this->send(&vecBuff[0], allLen)) {
					// remove from response map.
					if (auto resp = this->find(msgObjId)) {
						resp->remove(msgObjId);
					}

					delete waitResp;
					return response{ 0,"" };
				}

				// wait for gTimeout.
				if (!waitResp->sem.wait_for(gTimeout)) {
					// wait time out.
					if (auto resp = this->find(msgObjId)) {
						resp->remove(msgObjId);
					}
					delete waitResp;
					return response{ InvalidMsgId, "" };
				} else {
					// wait successfully.
					response res(waitResp->resp);
					delete waitResp;
					return res;
				}
			}

			// build msgId and msg.
			template<int N>
			int buildProto(uint16 msgId, ulong objId, const std::string &data, char(&buff)[N]) {
				// build header: message length.
				SCommonHead &head = *(SCommonHead*)buff;
				head.len = static_cast<uint32>((htonl(long(data.length() + sizeof(ulong) + sizeof(msgId)))));

				// message id.
				*(uint16*)(buff + gProto_head_size) = htons(msgId);

				// message object id.
				*(ulong*)(buff + gProto_head_size + sizeof(uint16)) = htonl(objId);

				// message content.
				memcpy(buff + gProto_head_size + sizeof(msgId) + sizeof(ulong), data.c_str(), data.length());
				return int(gProto_head_size + sizeof(msgId) + sizeof(ulong) + data.length());
			}

			// send message asynchronously.
			bool Send(uint16 msgId, const std::string &data) {
				if (sizeof(uint16) + sizeof(ulong) + data.length() > gMaxPacketSize) {
					return false;
				}

				char vecBuff[gMaxPacketSize];
				int allLen = buildProto(msgId, 0, data, vecBuff);
				return this->send(&vecBuff[0], allLen);
			}

			// close connection.
			void close() {
				if (isClosed()) return;
				m_conn->close();
			}

		protected:
			responseMap* find(ulong msgObjId) {
				if (isClosed()) {
					return nullptr;
				} else {
					return m_responses[msgObjId%gResponseSize];
				}
			}
			bool add(ulong objId, responseInfo* pInfo) {
				if (isClosed()) {
					return false;
				} else {
					m_responses[objId%gResponseSize]->add(objId, pInfo);
					return true;
				}
			}
			// send binary data
			bool send(const char *msg, size_t len) {
				if (isClosed()) {
					return false;
				}
				// send message(msg,len).
				m_conn->send(msg, len);
				return true;
			}

		protected:
			connSharePtr m_conn{ nullptr };
			std::atomic_bool m_closed{ true };

			// message unique object id generation.
			std::atomic_ulong m_nextId{ 1 };

			// response array map.
			static const int gResponseSize = 16;
			responseMap* m_responses[gResponseSize] = { nullptr };

			// message handler for asynchronous mode if possible.
			SyncMessageProcessFunc m_msgHandler{ nullptr };
			static const int gMaxPacketSize = 10 * 1024;

			// time out value.
			static constexpr std::chrono::milliseconds gTimeout = std::chrono::milliseconds{ 3000 };
		};
		// === synchronous session end ===//


		//=============================
		// CSession handler types.
		// MessageHandler: called by CSession::onRecv with the session, the received
		// bytes and their length, exactly as onRecv got them.
		// StatusHandler: called with true from CSession::onConnected and with false
		// from CSession::onTerminate.
		class CSession;
		using MessageHandler = std::function<void(CSession*, const char*, int)>;
		using StatusHandler = std::function<void(CSession*, bool isConnected)>;

		// tcp session(asynchronous mode) with CBigCodec codec.
		class CSession final : public ISession {
		public:
			CSession() = default;
			virtual ~CSession() = default;

		public:
			void onMessage(const char *msg, int len) {
				assert(len >= gProto_head_size && "data len is error");

				// msg binary
				auto msgData = msg + gProto_head_size;

				// msg id.
				auto msgId = ntohs(*(uint16*)msgData);

				// msg content
				auto size = len - gProto_head_size - gProto_message_id_size;
				auto &&body = std::string(msg + gProto_head_size + gProto_message_id_size, size);
				LogAdebug("obj:{},msg id:{},msg:{}", this, msgId, body.c_str());
				
				// send message back.
				this->send(msg, len);
			}
			virtual void onRecv(const char *msg, int len) override {
				if (m_msgHandler != nullptr) {
					m_msgHandler(this, msg, len);
				} else {
					onMessage(msg, len);
				}
			}
			virtual void onTerminate() override {
				m_closed = true;
				LogAinfo("{} {}:{} disconnected", m_id, m_conn->getRemoteAddr(), m_conn->getRemotePort());
				if (m_statusHandler != nullptr) {
					m_statusHandler(this, false);
				}
			}
			virtual void onConnected(connSharePtr conn) override {
				conn->setLinger(0);
				m_conn = conn;
				LogAinfo("{} {}:{} connected", m_id, conn->getRemoteAddr(), conn->getRemotePort());
				m_closed = false;
				if (m_statusHandler != nullptr) {
					m_statusHandler(this, true);
				}
			}
			virtual void release() override {
				delete this;
			}

			bool isClosed() const {
				return m_closed;
			}

			// session id.
			void setId(long long id) {
				m_id = id;
			}
			long long getId() const {
				return m_id;
			}

			void setMessageHandler(MessageHandler handler) {
				m_msgHandler = std::move(handler);
			}
			void setStatusHandler(StatusHandler& statusHandler) {
				m_statusHandler = std::move(statusHandler);
			}
			void setLinger(int sec) {
				if (isClosed()) return;
				m_conn->setLinger(sec);
			}

		public:
			// send sends binary data
			void send(const char* msg, size_t len) {
				if (isClosed()) return;
				m_conn->send(msg, len);
			}

			// sendMsg sends binary data.
			bool sendMsg(const std::string &msg) {
				if (msg.length() > gMaxPacketSize) {
					return false;
				}
				char vecBuff[gMaxPacketSize];
				char *buff = &vecBuff[0];

				SCommonHead &head = *(SCommonHead*)buff;
				head.len = static_cast<uint32>(htonl(long(msg.length())));
				memcpy(buff + gProto_head_size, msg.c_str(), msg.length());
				this->send(buff, gProto_head_size + msg.length());
				return true;
			}

			// sendMsg sends message with big codec.
			bool sendMsg(uint16 msgId, const std::string &data) {
				if (sizeof(uint16) + data.length() > gMaxPacketSize) {
					return false;
				}

				char vecBuff[gMaxPacketSize];
				auto size = buildProto(msgId, data, vecBuff, true);
				this->send(&vecBuff[0], size);
				return true;
			}

			// send_msg sends message with little codec.
			bool send_msg(uint16 msgId, const std::string &data) {
				if (sizeof(msgId) + data.length() > gMaxPacketSize) {
					return false;
				}
				char vecBuff[gMaxPacketSize];
				auto size = buildProto(msgId, data, vecBuff, false);
				this->send(&vecBuff[0], size);
				return true;
			}

			// close connection.
			void close() {
				if (isClosed()) return;
				m_conn->close();
			}

		private:
			// build protocol data big endian or little endian.
			template<int N>
			int buildProto(uint16 msgId, const std::string& data,
				char(&buff)[N], bool bigEndian) {

				// build packet header.
				SCommonHead& head = *(SCommonHead*)buff;
				if (bigEndian) {
					head.len = uint32(htonl(uint32(data.length()) + sizeof(msgId)));
					// message id.
					*(uint16*)(buff + gProto_head_size) = htons(msgId);
				}
				else {
					head.len = uint32(data.length()) + sizeof(msgId);
					// message id.
					*(uint16*)(buff + gProto_head_size) = msgId;
				}

				// copy message.
				memcpy(buff + gProto_head_size + sizeof(msgId), data.c_str(), data.length());
				return int(gProto_head_size + sizeof(msgId) + data.length());
			}

		protected:
			connSharePtr m_conn{ nullptr };

			// close flag.
			std::atomic_bool m_closed{ true };
		
			// session id.
			long long m_id{ 0 };
			
			// message handler
			MessageHandler m_msgHandler{ nullptr };

			// status handler 
			StatusHandler m_statusHandler{ nullptr };
			static const int gMaxPacketSize = 10 * 1024;
		};
		// ============================

		// Session factory that hands out plain CSession objects. Not copyable.
		class CSessionFactory final : public ISessionFactory {
		public:
			CSessionFactory() = default;
			CSessionFactory(const CSessionFactory& rhs) = delete;
			CSessionFactory& operator=(const CSessionFactory& rhs) = delete;
			virtual ~CSessionFactory() = default;

		public:
			// Returns a freshly allocated CSession.
			// (original note: try to use a CSession pool.)
			virtual ISession *createSession() override {
				ISession *created = new CSession();
				return created;
			}
		};

		// Session factory parameterized on the concrete session type. Not copyable.
		template <typename session>
		class CTemplateSessionFactory final : public ISessionFactory {
		public:
			CTemplateSessionFactory() = default;
			CTemplateSessionFactory(const CTemplateSessionFactory& rhs) = delete;
			CTemplateSessionFactory& operator=(const CTemplateSessionFactory& rhs) = delete;
			virtual ~CTemplateSessionFactory() = default;

		public:
			// Returns a freshly allocated, default-constructed session.
			virtual ISession *createSession() override {
				ISession *created = new session();
				return created;
			}

			// Deletes a session of the concrete type.
			void releaseSession(session *pSession) {
				delete pSession;
			}
		};
	}
}
